{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DADUIVuPIqYi"
      },
      "source": [
        "##### Copyright 2019 Google LLC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Cu1zPez8Ip1S"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nlIh_TPLI54s"
      },
      "source": [
        "# Graph regularization for document classification using natural graphs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pL9fF9FWI-Q1"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_mlp_cora\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/neural-structured-learning/blob/master/g3doc/tutorials/graph_keras_mlp_cora.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/neural-structured-learning/blob/master/g3doc/tutorials/graph_keras_mlp_cora.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jJlN8oxhNGto"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Graph regularization is a specific technique under the broader paradigm of\n",
        "Neural Graph Learning\n",
        "([Bui et al., 2018](https://research.google/pubs/pub46568.pdf)). The core\n",
        "idea is to train neural network models with a graph-regularized objective,\n",
        "harnessing both labeled and unlabeled data.\n",
        "\n",
        "In this tutorial, we will explore the use of graph regularization to classify\n",
        "documents that form a natural (organic) graph.\n",
        "\n",
        "The general recipe for creating a graph-regularized model using the Neural\n",
        "Structured Learning (NSL) framework is as follows:\n",
        "\n",
        "1.  Generate training data from the input graph and sample features. Nodes in\n",
        "    the graph correspond to samples and edges in the graph correspond to\n",
        "    similarity between pairs of samples. The resulting training data will\n",
        "    contain neighbor features in addition to the original node features.\n",
        "2.  Create a neural network as a base model using the `Keras` sequential,\n",
        "    functional, or subclass API.\n",
        "3.  Wrap the base model with the **`GraphRegularization`** wrapper class, which\n",
        "    is provided by the NSL framework, to create a new graph `Keras` model. This\n",
        "    new model will include a graph regularization loss as the regularization\n",
        "    term in its training objective.\n",
        "4.  Train and evaluate the graph `Keras` model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nDOFbB34KY1R"
      },
      "source": [
        "## Setup\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hgSLDi0SyBuO"
      },
      "source": [
        "Install the Neural Structured Learning package."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uVnjPmOaQlnH"
      },
      "outputs": [],
      "source": [
        "!pip install --quiet neural-structured-learning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0AiNrPJaX_Lb"
      },
      "source": [
        "## Dependencies and imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5sEamf-wZJkX"
      },
      "outputs": [],
      "source": [
        "import neural_structured_learning as nsl\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "# Resets notebook state\n",
        "tf.keras.backend.clear_session()\n",
        "\n",
        "print(\"Version: \", tf.__version__)\n",
        "print(\"Eager mode: \", tf.executing_eagerly())\n",
        "print(\n",
        "    \"GPU is\",\n",
        "    \"available\" if tf.config.list_physical_devices(\"GPU\") else \"NOT AVAILABLE\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RtbcGZ_N6Ll9"
      },
      "source": [
        "## Cora dataset\n",
        "\n",
        "The [Cora dataset](https://linqs.soe.ucsc.edu/data) is a citation graph where\n",
        "nodes represent machine learning papers and edges represent citations between\n",
        "pairs of papers. The task involved is document classification where the goal is\n",
        "to categorize each paper into one of 7 categories. In other words, this is a\n",
        "multi-class classification problem with 7 classes.\n",
        "\n",
        "### Graph\n",
        "\n",
        "The original graph is directed. However, for the purpose of this example, we\n",
        "consider the undirected version of this graph. So, if paper A cites paper B, we\n",
        "also consider paper B to have cited A. Although this is not necessarily true, in\n",
        "this example, we consider citations as a proxy for similarity, which is usually\n",
        "a commutative property.\n",
        "\n",
        "### Features\n",
        "\n",
        "Each paper in the input effectively contains 2 features:\n",
        "\n",
        "1.  **Words**: A dense, multi-hot bag-of-words representation of the text in the\n",
        "    paper. The vocabulary for the Cora dataset contains 1433 unique words. So,\n",
        "    the length of this feature is 1433, and the value at position 'i' is 0/1\n",
        "    indicating whether word 'i' in the vocabulary exists in the given paper or\n",
        "    not.\n",
        "\n",
        "2.  **Label**: A single integer representing the class ID (category) of the paper."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p2-FVpVHEyIS"
      },
      "source": [
        "### Download the Cora dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2nSZjKqwE6Rn"
      },
      "outputs": [],
      "source": [
        "!wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz\n",
        "!tar -C /tmp -xvzf /tmp/cora.tgz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ylYP32_IoqZI"
      },
      "source": [
        "### Convert the Cora data to the NSL format\n",
        "\n",
        "In order to preprocess the Cora dataset and convert it to the format required by\n",
        "Neural Structured Learning, we will run the **'preprocess_cora_dataset.py'**\n",
        "script, which is included in the NSL github repository. This script does the\n",
        "following:\n",
        "\n",
        "1.  Generate neighbor features using the original node features and the graph.\n",
        "2.  Generate train and test data splits containing `tf.train.Example` instances.\n",
        "3.  Persist the resulting train and test data in the `TFRecord` format."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Myz01LVZQ8Uh"
      },
      "outputs": [],
      "source": [
        "!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py\n",
        "\n",
        "!python preprocess_cora_dataset.py \\\n",
        "--input_cora_content=/tmp/cora/cora.content \\\n",
        "--input_cora_graph=/tmp/cora/cora.cites \\\n",
        "--max_nbrs=5 \\\n",
        "--output_train_data=/tmp/cora/train_merged_examples.tfr \\\n",
        "--output_test_data=/tmp/cora/test_examples.tfr"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DXoyHIdV5hEe"
      },
      "source": [
        "## Global variables\n",
        "\n",
        "The file paths to the train and test data are based on the command line flag\n",
        "values used to invoke the **'preprocess_cora_dataset.py'** script above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kKAmzKIH6I9f"
      },
      "outputs": [],
      "source": [
        "### Experiment dataset\n",
        "TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'\n",
        "TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'\n",
        "\n",
        "### Constants used to identify neighbor features in the input.\n",
        "NBR_FEATURE_PREFIX = 'NL_nbr_'\n",
        "NBR_WEIGHT_SUFFIX = '_weight'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2gYWAqJqZ76I"
      },
      "source": [
        "## Hyperparameters\n",
        "\n",
        "We will use an instance of `HParams` to include various hyperparameters and\n",
        "constants used for training and evaluation. We briefly describe each of them\n",
        "below:\n",
        "\n",
        "-   **num_classes**: There are a total 7 different classes\n",
        "\n",
        "-   **max_seq_length**: This is the size of the vocabulary and all instances in\n",
        "    the input have a dense multi-hot, bag-of-words representation. In other\n",
        "    words, a value of 1 for a word indicates that the word is present in the\n",
        "    input and a value of 0 indicates that it is not.\n",
        "\n",
        "-   **distance_type**: This is the distance metric used to regularize the sample\n",
        "    with its neighbors.\n",
        "\n",
        "-   **graph_regularization_multiplier**: This controls the relative weight of\n",
        "    the graph regularization term in the overall loss function.\n",
        "\n",
        "-   **num_neighbors**: The number of neighbors used for graph regularization.\n",
        "    This value has to be less than or equal to the `max_nbrs` command-line\n",
        "    argument used above when running `preprocess_cora_dataset.py`.\n",
        "\n",
        "-   **num_fc_units**: The number of fully connected layers in our neural\n",
        "    network.\n",
        "\n",
        "-   **train_epochs**: The number of training epochs.\n",
        "\n",
        "-   **batch_size**: Batch size used for training and evaluation.\n",
        "\n",
        "-   **dropout_rate**: Controls the rate of dropout following each fully\n",
        "    connected layer\n",
        "\n",
        "-   **eval_steps**: The number of batches to process before deeming evaluation\n",
        "    is complete. If set to `None`, all instances in the test set are evaluated."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N03EFEkcOBaJ"
      },
      "outputs": [],
      "source": [
        "class HParams(object):\n",
        "  \"\"\"Hyperparameters used for training.\"\"\"\n",
        "  def __init__(self):\n",
        "    ### dataset parameters\n",
        "    self.num_classes = 7\n",
        "    self.max_seq_length = 1433\n",
        "    ### neural graph learning parameters\n",
        "    self.distance_type = nsl.configs.DistanceType.L2\n",
        "    self.graph_regularization_multiplier = 0.1\n",
        "    self.num_neighbors = 1\n",
        "    ### model architecture\n",
        "    self.num_fc_units = [50, 50]\n",
        "    ### training parameters\n",
        "    self.train_epochs = 100\n",
        "    self.batch_size = 128\n",
        "    self.dropout_rate = 0.5\n",
        "    ### eval parameters\n",
        "    self.eval_steps = None  # All instances in the test set are evaluated.\n",
        "\n",
        "HPARAMS = HParams()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IMp34x0MfMMZ"
      },
      "source": [
        "## Load train and test data\n",
        "\n",
        "As described earlier in this notebook, the input training and test data have\n",
        "been created by the **'preprocess_cora_dataset.py'**. We will load them into two\n",
        "`tf.data.Dataset` objects -- one for train and one for test.\n",
        "\n",
        "In the input layer of our model, we will extract not just the 'words' and the\n",
        "'label' features from each sample, but also corresponding neighbor features\n",
        "based on the `hparams.num_neighbors` value. Instances with fewer neighbors than\n",
        "`hparams.num_neighbors` will be assigned dummy values for those non-existent\n",
        "neighbor features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LCKQVKGee1ST"
      },
      "outputs": [],
      "source": [
        "def make_dataset(file_path, training=False):\n",
        "  \"\"\"Creates a `tf.data.TFRecordDataset`.\n",
        "\n",
        "  Args:\n",
        "    file_path: Name of the file in the `.tfrecord` format containing\n",
        "      `tf.train.Example` objects.\n",
        "    training: Boolean indicating if we are in training mode.\n",
        "\n",
        "  Returns:\n",
        "    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`\n",
        "    objects.\n",
        "  \"\"\"\n",
        "\n",
        "  def parse_example(example_proto):\n",
        "    \"\"\"Extracts relevant fields from the `example_proto`.\n",
        "\n",
        "    Args:\n",
        "      example_proto: An instance of `tf.train.Example`.\n",
        "\n",
        "    Returns:\n",
        "      A pair whose first value is a dictionary containing relevant features\n",
        "      and whose second value contains the ground truth label.\n",
        "    \"\"\"\n",
        "    # The 'words' feature is a multi-hot, bag-of-words representation of the\n",
        "    # original raw text. A default value is required for examples that don't\n",
        "    # have the feature.\n",
        "    feature_spec = {\n",
        "        'words':\n",
        "            tf.io.FixedLenFeature([HPARAMS.max_seq_length],\n",
        "                                  tf.int64,\n",
        "                                  default_value=tf.constant(\n",
        "                                      0,\n",
        "                                      dtype=tf.int64,\n",
        "                                      shape=[HPARAMS.max_seq_length])),\n",
        "        'label':\n",
        "            tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n",
        "    }\n",
        "    # We also extract corresponding neighbor features in a similar manner to\n",
        "    # the features above during training.\n",
        "    if training:\n",
        "      for i in range(HPARAMS.num_neighbors):\n",
        "        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n",
        "        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,\n",
        "                                         NBR_WEIGHT_SUFFIX)\n",
        "        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(\n",
        "            [HPARAMS.max_seq_length],\n",
        "            tf.int64,\n",
        "            default_value=tf.constant(\n",
        "                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))\n",
        "\n",
        "        # We assign a default value of 0.0 for the neighbor weight so that\n",
        "        # graph regularization is done on samples based on their exact number\n",
        "        # of neighbors. In other words, non-existent neighbors are discounted.\n",
        "        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(\n",
        "            [1], tf.float32, default_value=tf.constant([0.0]))\n",
        "\n",
        "    features = tf.io.parse_single_example(example_proto, feature_spec)\n",
        "\n",
        "    label = features.pop('label')\n",
        "    return features, label\n",
        "\n",
        "  dataset = tf.data.TFRecordDataset([file_path])\n",
        "  if training:\n",
        "    dataset = dataset.shuffle(10000)\n",
        "  dataset = dataset.map(parse_example)\n",
        "  dataset = dataset.batch(HPARAMS.batch_size)\n",
        "  return dataset\n",
        "\n",
        "\n",
        "train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)\n",
        "test_dataset = make_dataset(TEST_DATA_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hEWEGhtVzI2p"
      },
      "source": [
        "Let's peek into the train dataset to look at its contents."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gx-YFaBoCOcl"
      },
      "outputs": [],
      "source": [
        "for feature_batch, label_batch in train_dataset.take(1):\n",
        "  print('Feature list:', list(feature_batch.keys()))\n",
        "  print('Batch of inputs:', feature_batch['words'])\n",
        "  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')\n",
        "  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)\n",
        "  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])\n",
        "  print('Batch of neighbor weights:',\n",
        "        tf.reshape(feature_batch[nbr_weight_key], [-1]))\n",
        "  print('Batch of labels:', label_batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J7B30hRPzOBE"
      },
      "source": [
        "Let's peek into the test dataset to look at its contents."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kNJuM9yJiFtj"
      },
      "outputs": [],
      "source": [
        "for feature_batch, label_batch in test_dataset.take(1):\n",
        "  print('Feature list:', list(feature_batch.keys()))\n",
        "  print('Batch of inputs:', feature_batch['words'])\n",
        "  print('Batch of labels:', label_batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OZhsTo8yio8i"
      },
      "source": [
        "## Model definition\n",
        "\n",
        "In order to demonstrate the use of graph regularization, we build a base model\n",
        "for this problem first. We will use a simple feed-forward neural network with 2\n",
        "hidden layers and dropout in between. We illustrate the creation of the base\n",
        "model using all model types supported by the `tf.Keras` framework -- sequential,\n",
        "functional, and subclass."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z_kBDDfFiuyI"
      },
      "source": [
        "### Sequential base model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pecJmegfWijx"
      },
      "outputs": [],
      "source": [
        "def make_mlp_sequential_model(hparams):\n",
        "  \"\"\"Creates a sequential multi-layer perceptron model.\"\"\"\n",
        "  model = tf.keras.Sequential()\n",
        "  model.add(\n",
        "      tf.keras.layers.InputLayer(\n",
        "          input_shape=(hparams.max_seq_length,), name='words'))\n",
        "  # Input is already one-hot encoded in the integer format. We cast it to\n",
        "  # floating point format here.\n",
        "  model.add(\n",
        "      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))\n",
        "  for num_units in hparams.num_fc_units:\n",
        "    model.add(tf.keras.layers.Dense(num_units, activation='relu'))\n",
        "    # For sequential models, by default, Keras ensures that the 'dropout' layer\n",
        "    # is invoked only during training.\n",
        "    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))\n",
        "  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))\n",
        "  return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FfZWxqVPiz_f"
      },
      "source": [
        "### Functional base model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TU-YE4NXi7PK"
      },
      "outputs": [],
      "source": [
        "def make_mlp_functional_model(hparams):\n",
        "  \"\"\"Creates a functional API-based multi-layer perceptron model.\"\"\"\n",
        "  inputs = tf.keras.Input(\n",
        "      shape=(hparams.max_seq_length,), dtype='int64', name='words')\n",
        "\n",
        "  # Input is already one-hot encoded in the integer format. We cast it to\n",
        "  # floating point format here.\n",
        "  cur_layer = tf.keras.layers.Lambda(\n",
        "      lambda x: tf.keras.backend.cast(x, tf.float32))(\n",
        "          inputs)\n",
        "\n",
        "  for num_units in hparams.num_fc_units:\n",
        "    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)\n",
        "    # For functional models, by default, Keras ensures that the 'dropout' layer\n",
        "    # is invoked only during training.\n",
        "    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)\n",
        "\n",
        "  outputs = tf.keras.layers.Dense(\n",
        "      hparams.num_classes, activation='softmax')(\n",
        "          cur_layer)\n",
        "\n",
        "  model = tf.keras.Model(inputs, outputs=outputs)\n",
        "  return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8LmAhITRi8M0"
      },
      "source": [
        "### Subclass base model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4l1aK9b_jAw6"
      },
      "outputs": [],
      "source": [
        "def make_mlp_subclass_model(hparams):\n",
        "  \"\"\"Creates a multi-layer perceptron subclass model in Keras.\"\"\"\n",
        "\n",
        "  class MLP(tf.keras.Model):\n",
        "    \"\"\"Subclass model defining a multi-layer perceptron.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "      super(MLP, self).__init__()\n",
        "      # Input is already one-hot encoded in the integer format. We create a\n",
        "      # layer to cast it to floating point format here.\n",
        "      self.cast_to_float_layer = tf.keras.layers.Lambda(\n",
        "          lambda x: tf.keras.backend.cast(x, tf.float32))\n",
        "      self.dense_layers = [\n",
        "          tf.keras.layers.Dense(num_units, activation='relu')\n",
        "          for num_units in hparams.num_fc_units\n",
        "      ]\n",
        "      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)\n",
        "      self.output_layer = tf.keras.layers.Dense(\n",
        "          hparams.num_classes, activation='softmax')\n",
        "\n",
        "    def call(self, inputs, training=False):\n",
        "      cur_layer = self.cast_to_float_layer(inputs['words'])\n",
        "      for dense_layer in self.dense_layers:\n",
        "        cur_layer = dense_layer(cur_layer)\n",
        "        cur_layer = self.dropout_layer(cur_layer, training=training)\n",
        "\n",
        "      outputs = self.output_layer(cur_layer)\n",
        "\n",
        "      return outputs\n",
        "\n",
        "  return MLP()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BbGpIbscjIAo"
      },
      "source": [
        "## Create base model(s)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fzMBxiMGjCO4"
      },
      "outputs": [],
      "source": [
        "# Create a base MLP model using the functional API.\n",
        "# Alternatively, you can also create a sequential or subclass base model using\n",
        "# the make_mlp_sequential_model() or make_mlp_subclass_model() functions\n",
        "# respectively, defined above. Note that if a subclass model is used, its\n",
        "# summary cannot be generated until it is built.\n",
        "base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)\n",
        "base_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T1uEboimjiW7"
      },
      "source": [
        "## Train base MLP model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JALapM4PoCvi"
      },
      "outputs": [],
      "source": [
        "# Compile and train the base MLP model\n",
        "base_model.compile(\n",
        "    optimizer='adam',\n",
        "    loss='sparse_categorical_crossentropy',\n",
        "    metrics=['accuracy'])\n",
        "base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gPRioqydQD_8"
      },
      "source": [
        "## Evaluate base MLP model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2NcsJVt6FSmZ"
      },
      "outputs": [],
      "source": [
        "# Helper function to print evaluation metrics.\n",
        "def print_metrics(model_desc, eval_metrics):\n",
        "  \"\"\"Prints evaluation metrics.\n",
        "\n",
        "  Args:\n",
        "    model_desc: A description of the model.\n",
        "    eval_metrics: A dictionary mapping metric names to corresponding values. It\n",
        "      must contain the loss and accuracy metrics.\n",
        "  \"\"\"\n",
        "  print('\\n')\n",
        "  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])\n",
        "  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])\n",
        "  if 'graph_loss' in eval_metrics:\n",
        "    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-myfttwIQAwc"
      },
      "outputs": [],
      "source": [
        "eval_results = dict(\n",
        "    zip(base_model.metrics_names,\n",
        "        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))\n",
        "print_metrics('Base MLP model', eval_results)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bGwSzS9Spaiu"
      },
      "source": [
        "## Train MLP model with graph regularization\n",
        "\n",
        "Incorporating graph regularization into the loss term of an existing\n",
        "`tf.Keras.Model` requires just a few lines of code. The base model is wrapped to\n",
        "create a new `tf.Keras` subclass model, whose loss includes graph\n",
        "regularization."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vuIGN8KQH0jR"
      },
      "source": [
        "To assess the incremental benefit of graph regularization, we will create a new\n",
        "base model instance. This is because `base_model` has already been trained for a\n",
        "few iterations, and reusing this trained model to create a graph-regularized\n",
        "model will not be a fair comparison for `base_model`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6fvLei9dLCH0"
      },
      "outputs": [],
      "source": [
        "# Build a new base MLP model.\n",
        "base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(\n",
        "    HPARAMS)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HT3mpC8Lo1UZ"
      },
      "outputs": [],
      "source": [
        "# Wrap the base MLP model with graph regularization.\n",
        "graph_reg_config = nsl.configs.make_graph_reg_config(\n",
        "    max_neighbors=HPARAMS.num_neighbors,\n",
        "    multiplier=HPARAMS.graph_regularization_multiplier,\n",
        "    distance_type=HPARAMS.distance_type,\n",
        "    sum_over_axis=-1)\n",
        "graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,\n",
        "                                                graph_reg_config)\n",
        "graph_reg_model.compile(\n",
        "    optimizer='adam',\n",
        "    loss='sparse_categorical_crossentropy',\n",
        "    metrics=['accuracy'])\n",
        "graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6409ylRVQS7q"
      },
      "source": [
        "## Evaluate MLP model with graph regularization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1TsOE1bAQTqD"
      },
      "outputs": [],
      "source": [
        "eval_results = dict(\n",
        "    zip(graph_reg_model.metrics_names,\n",
        "        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))\n",
        "print_metrics('MLP + graph regularization', eval_results)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Adc-r84EOSQi"
      },
      "source": [
        "The graph-regularized model's accuracy is about 2-3% higher than that of the\n",
        "base model (`base_model`)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OEXQHFNvJQJe"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "We have demonstrated the use of graph regularization for document classification\n",
        "on a natural citation graph (Cora) using the Neural Structured Learning (NSL)\n",
        "framework. Our [advanced tutorial](graph_keras_lstm_imdb.ipynb) involves\n",
        "synthesizing graphs based on sample embeddings before training a neural network\n",
        "with graph regularization. This approach is useful if the input does not contain\n",
        "an explicit graph.\n",
        "\n",
        "We encourage users to experiment further by varying the amount of supervision as\n",
        "well as trying different neural architectures for graph regularization."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "graph_keras_mlp_cora.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
